{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recommender Systems with GNNs\n",
    "\n",
    "The academia and industries have witnessed various attempts and applications of applying GNNs to business applications including recommender systems, fraud detection, etc.\n",
    "\n",
    "In this hands-on tutorial, we will see how GNNs can be used for a recommender system.\n",
    "\n",
    "## Recommender System: Overview\n",
    "\n",
    "Recommender systems are typically used on E-commerce or similar platforms.  It is responsible for predicting the items that are most likely to be interacted by a user, given his/her previous interaction history.\n",
    "\n",
    "Based on whether the interactions have a *rating* that signifies how a user likes an item, we could categorize the interaction data as having either *explicit feedback* if the data has ratings, or *implicit feedback* otherwise.\n",
    "\n",
    "There are a lot of methodologies for recommendation.  Here we will focus on *matrix-factorization-based* methods that expresses users and items as *embeddings* (i.e. vectors of same dimensionality), and the preference of a user to an item as the *dot product* between user and item embeddings.\n",
    "\n",
    "More specifically, for each user we can assign a learnable vector $\\boldsymbol{u}_i$, and for each item we can assign another learnable vector $\\boldsymbol{v}_j$.  Additionally, we also learn a bias for each user $\\beta_i$ and each item $\\gamma_j$, representing the baseline rating of each user and item.  Preference prediction is expressed as $\\hat{r}_{i,j}=\\boldsymbol{u}_i^\\top \\boldsymbol{v}_j + \\beta_i + \\gamma_j$.\n",
    "\n",
    "For explicit feedback datasets, we try to minimize the difference between prediction and ground truth $r_{i,j}$ for all user-item interactions:\n",
    "\n",
    "$$\n",
    "\\min_{\\boldsymbol{u},\\boldsymbol{v}} \\sum_{(i,j)\\in \\mathcal{D}} \\left(\\hat{r}_{i,j} - r_{i,j}\\right)^2\n",
    "$$\n",
    "\n",
    "For implicit feedback datasets, the function to minimize is more complicated, so in this tutorial we will focus on explicit feedback datasets.\n",
    "\n",
    "Other recommender system families include neighborhood-based models, factorization machines, etc.  But we won't go through them here.\n",
    "\n",
    "## What Does a GNN Do in Recommender Systems?\n",
    "\n",
    "GNNs most commonly appear as a another method of computing the user and item embeddings: instead of directly assigning a learnable vector as user and item embedding, we compute the embeddings with a GNN.  The graph for the GNN would be the interaction data itself, with users and items as two types of nodes and the interactions as edges connecting the nodes.  \n",
    "\n",
    "Comparing against other approaches, the biggest advantage of GNNs in recommender systems is that it can potentially combine the interactions data with other kind of relational data.  If we have additional relational data such as social networks or knowledge graphs, we can combine those data with the interaction data to form a bigger graph, and still run the same GNN as a recommender system.  Other advantages include ability of introducing the information of neighbors to embeddings, and the ability to combine both user/item features and the neighborhood structure of users/items, etc.\n",
    "\n",
    "In our hands-on, we will use the MovieLens-100K dataset.\n",
    "\n",
    "Note that to simplify stuff in our tutorial,\n",
    "\n",
    "* We ignore the user and item features.\n",
    "* We would not be considering new users and items."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "train_data = pd.read_csv('ua.base', sep='\\t', header=None, names=['user_id', 'item_id', 'rating', 'timestamp'])\n",
    "test_data = pd.read_csv('ua.test', sep='\\t', header=None, names=['user_id', 'item_id', 'rating', 'timestamp'])\n",
    "\n",
    "# Remove the entries with users and items not appearing in the training dataset\n",
    "test_data = test_data[test_data['user_id'].isin(train_data['user_id']) &\n",
    "                      test_data['item_id'].isin(train_data['item_id'])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heterogeneous Graphs in DGL\n",
    "\n",
    "If we treat our interaction data as a graph, we will have two types of nodes: users and items.  We call a graph that has multiple node types and edge types a *heterogeneous* graph.\n",
    "\n",
    "With MovieLens-100K loaded, we will see how to build a heterogeneous graph from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import torch\n",
    "\n",
    "# Relabel the user IDs and item IDs to integers\n",
    "train_data = train_data.astype({'user_id': 'category', 'item_id': 'category'})\n",
    "test_data = test_data.astype({'user_id': 'category', 'item_id': 'category'})\n",
    "# We need to keep the relabeling between training and test data consistent\n",
    "test_data['user_id'].cat.set_categories(train_data['user_id'].cat.categories, inplace=True)\n",
    "test_data['item_id'].cat.set_categories(train_data['item_id'].cat.categories, inplace=True)\n",
    "\n",
    "train_user_ids = torch.LongTensor(train_data['user_id'].cat.codes.values)\n",
    "train_item_ids = torch.LongTensor(train_data['item_id'].cat.codes.values)\n",
    "train_ratings = torch.LongTensor(train_data['rating'].values)\n",
    "test_user_ids = torch.LongTensor(test_data['user_id'].cat.codes.values)\n",
    "test_item_ids = torch.LongTensor(test_data['item_id'].cat.codes.values)\n",
    "test_ratings = torch.LongTensor(test_data['rating'].values)\n",
    "\n",
    "# Build graph\n",
    "graph = dgl.heterograph({\n",
    "    # Heterogeneous graphs are organized as a dictionary of edges connecting two types of nodes.\n",
    "    # We specify the edges of a type simply with a pair of user ID array and item ID array.\n",
    "    ('user', 'watched', 'item'): (train_user_ids, train_item_ids),\n",
    "    # Since DGL graphs are directional, we need an inverse relation from items to users as well.\n",
    "    ('item', 'watched-by', 'user'): (train_item_ids, train_user_ids)\n",
    "})\n",
    "\n",
    "# Assign ratings\n",
    "graph.edges['watched'].data['rating'] = torch.LongTensor(train_ratings)\n",
    "graph.edges['watched-by'].data['rating'] = torch.LongTensor(train_ratings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model on Heterogeneous Graphs\n",
    "\n",
    "### Define Dataset to Sample Minibatches from\n",
    "\n",
    "The first thing we need to do for GNN training is to define how a minibatch of examples look like.  For rating prediction tasks, a minibatch consists of pairs of users and items.  So we create a torch `Dataset` object that contains all training pairs of users and items as well as their ratings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "train_dataset = TensorDataset(train_user_ids, train_item_ids, train_ratings)\n",
    "test_dataset = TensorDataset(test_user_ids, test_item_ids, test_ratings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Minibatch & Neighbor Sampler\n",
    "\n",
    "For an interaction we wish to compute the representation of its user and its item with a multi-layer GNN.  Therefore, as in the previous tutorial, we need to define a minibatch sampler that gets us the computation dependency of the user and item nodes for the multi-layer GNN.  At the same time, we also want to keep track of the training interactions themselves, i.e. which user is interacting which item.\n",
    "\n",
    "In DGL, rating prediction could be achieved with the following steps:\n",
    "\n",
    "1. Create a *pair graph* that consists of all the users and items in the dataset, but only with the interactions in the training minibatch.\n",
    "2. *Compact* the graph to remove all unnecessary nodes, keeping only the nodes with some edges going in or out.\n",
    "3. Construct blocks bottom-up for computing the multi-layer output of the nodes, as in the previous tutorial.\n",
    "4. Propagate messages top-down to obtain the output from the GNN.\n",
    "5. Copy the output to the pair graph and compute the preference with `apply_edges()` method.\n",
    "\n",
    "The following cell has the code of completing step 1 to step 3, which constitutes the steps of obtaining computation dependency for the interactions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MinibatchSampler(object):\n",
    "    def __init__(self, graph, num_layers):\n",
    "        self.graph = graph\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "    def sample(self, batch):\n",
    "        # Convert the list of user-item-rating triplets into a triplet of users, items, and ratings\n",
    "        users, items, ratings = zip(*batch)\n",
    "        users = torch.stack(users)\n",
    "        items = torch.stack(items)\n",
    "        ratings = torch.stack(ratings)\n",
    "        \n",
    "        # Create a pair graph (Step 1)\n",
    "        pair_graph = dgl.heterograph({\n",
    "            ('user', 'watched', 'item'): (users, items)\n",
    "        })\n",
    "        \n",
    "        # Compact the graph (Step 2)\n",
    "        pair_graph = dgl.compact_graphs(pair_graph)\n",
    "        # Assign ratings to the graph\n",
    "        pair_graph.edata['rating'] = ratings\n",
    "        \n",
    "        # Construct blocks (Step 3)\n",
    "        seeds = {'user': pair_graph.nodes['user'].data[dgl.NID],\n",
    "                 'item': pair_graph.nodes['item'].data[dgl.NID]}\n",
    "        blocks = []\n",
    "        for i in range(self.num_layers):\n",
    "            # We take all neighbors to form the sampled graph for computing the node representations on the\n",
    "            # current layer.\n",
    "            sampled_graph = dgl.in_subgraph(self.graph, seeds)\n",
    "            # Find the sampled edge IDs for both directions\n",
    "            sampled_eids = sampled_graph.edges['watched'].data[dgl.EID]\n",
    "            sampled_eids_rev = sampled_graph.edges['watched-by'].data[dgl.EID]\n",
    "            \n",
    "            # A subtlety of rating prediction and link prediction is that, when we train on the pair of user A\n",
    "            # and item 1, we don't want to actually tell the GNN that \"user A has a connection to item 1\".  So\n",
    "            # we should remove all edges connecting the training pairs from the sampled graph.\n",
    "            _, _, edges_to_remove = sampled_graph.edge_ids(users, items, etype='watched', force_multi=True)\n",
    "            _, _, edges_to_remove_rev = sampled_graph.edge_ids(items, users, etype='watched-by', force_multi=True)\n",
    "            sampled_with_edges_removed = dgl.remove_edges(\n",
    "                sampled_graph, {'watched': edges_to_remove, 'watched-by': edges_to_remove_rev})\n",
    "            sampled_eids = sampled_eids[sampled_with_edges_removed.edges['watched'].data[dgl.EID]]\n",
    "            sampled_eids_rev = sampled_eids_rev[sampled_with_edges_removed.edges['watched-by'].data[dgl.EID]]\n",
    "            \n",
    "            # Create a block from the sampled graph.\n",
    "            block = dgl.to_block(sampled_with_edges_removed, seeds)\n",
    "            blocks.insert(0, block)\n",
    "            seeds = {'user': block.srcnodes['user'].data[dgl.NID],\n",
    "                     'item': block.srcnodes['item'].data[dgl.NID]}\n",
    "            block.edges['watched'].data['rating'] = \\\n",
    "                self.graph.edges['watched'].data['rating'][sampled_eids]\n",
    "            block.edges['watched-by'].data['rating'] = \\\n",
    "                self.graph.edges['watched-by'].data['rating'][sampled_eids_rev]\n",
    "            \n",
    "        return pair_graph, blocks\n",
    "\n",
    "# In this tutorial we consider 1-layer GNNs.\n",
    "NUM_LAYERS = 1\n",
    "BATCH_SIZE = 500\n",
    "sampler = MinibatchSampler(graph, NUM_LAYERS)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=sampler.sample, shuffle=True)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=sampler.sample, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Model\n",
    "\n",
    "We choose a simplified version of Graph Convolutional Matrix Completion (GCMC) as the model.  Each node sends a message to its neighbors by computing a rating-specific projection of the node's representation.  Then each node gathers the messages and takes an average, projecting it afterwards.\n",
    "\n",
    "More specifically, given a user node $i$ with its input representations $u_i^{l-1}$ on layer $l$, we\n",
    "\n",
    "1. Find all neighbors, i.e. interacted items, of $i$.\n",
    "2. For each item $j$, check the rating $r_{ij}$ between user $i$ and item $j$.  The message sent from item $j$ to user $i$ would be a rating-specific linear projection $m_{j\\to i}^l \\gets W_{r_{ij}}^l v_j^{l-1}$.  $W_{r_{ij}}^l$ is a learnable matrix.\n",
    "3. The user $i$ aggregates all incoming messages by averaging and updates itself: $u_i^l \\gets \\mathrm{ReLU}(W^l[\\mathrm{Average}(m_{j\\to i}^l); u_i^{l-1}])$.  $W^l$ is again a learnable matrix.\n",
    "\n",
    "We can similarly compute $v_j$, the representation of an item node $j$.\n",
    "\n",
    "For an implementation faithful to the paper, please see our model examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import dgl.function as fn\n",
    "\n",
    "class GCMCLayer(nn.Module):\n",
    "    def __init__(self, hidden_dims, num_ratings):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Weights for user nodes\n",
    "        self.W_r = nn.Parameter(torch.randn(num_ratings + 1, hidden_dims, hidden_dims))\n",
    "        self.W = nn.Linear(hidden_dims * 2, hidden_dims)\n",
    "        \n",
    "        # Weights for item nodes\n",
    "        self.V_r = nn.Parameter(torch.randn(num_ratings + 1, hidden_dims, hidden_dims))\n",
    "        self.V = nn.Linear(hidden_dims * 2, hidden_dims)\n",
    "        \n",
    "    def compute_message(self, W, edges):\n",
    "        W_r = W[edges.data['rating']]\n",
    "        h = edges.src['h']\n",
    "        m = (W_r @ h.unsqueeze(-1)).squeeze(2)\n",
    "        return m\n",
    "        \n",
    "    def forward(self, block, input_user_features, input_item_features):\n",
    "        with block.local_scope():\n",
    "            block.srcnodes['user'].data['h'] = input_user_features\n",
    "            block.srcnodes['item'].data['h'] = input_item_features\n",
    "            \n",
    "            block.dstnodes['user'].data['h'] = input_user_features[:block.number_of_dst_nodes('user')]\n",
    "            block.dstnodes['item'].data['h'] = input_item_features[:block.number_of_dst_nodes('item')]\n",
    "            \n",
    "            # Compute messages\n",
    "            block.apply_edges(\n",
    "                lambda edges: {'m': self.compute_message(self.W_r, edges)},\n",
    "                etype='watched')\n",
    "            block.apply_edges(\n",
    "                lambda edges: {'m': self.compute_message(self.V_r, edges)},\n",
    "                etype='watched-by')\n",
    "            \n",
    "            # Aggregate messages\n",
    "            block.update_all(fn.copy_e('m', 'm'), fn.mean('m', 'h_neigh'), etype='watched')\n",
    "            block.update_all(fn.copy_e('m', 'm'), fn.mean('m', 'h_neigh'), etype='watched-by')\n",
    "            \n",
    "            # Updates the representations of output users and items\n",
    "            block.dstnodes['user'].data['h'] = F.relu(\n",
    "                self.W(torch.cat([block.dstnodes['user'].data['h'], block.dstnodes['user'].data['h_neigh']], 1)))\n",
    "            block.dstnodes['item'].data['h'] = F.relu(\n",
    "                self.W(torch.cat([block.dstnodes['item'].data['h'], block.dstnodes['item'].data['h_neigh']], 1)))\n",
    "            \n",
    "            return block.dstnodes['user'].data['h'], block.dstnodes['item'].data['h']\n",
    "        \n",
    "class GCMCRating(nn.Module):\n",
    "    def __init__(self, num_users, num_items, hidden_dims, num_ratings, num_layers):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.user_embeddings = nn.Embedding(num_users, hidden_dims)\n",
    "        self.item_embeddings = nn.Embedding(num_items, hidden_dims)\n",
    "        \n",
    "        self.layers = nn.ModuleList([\n",
    "            GCMCLayer(hidden_dims, num_ratings) for _ in range(num_layers)])\n",
    "        \n",
    "        self.W = nn.Linear(hidden_dims, hidden_dims)\n",
    "        self.V = nn.Linear(hidden_dims, hidden_dims)\n",
    "        \n",
    "    def forward(self, pair_graph, blocks):\n",
    "        # Propagate messages top-down (Step 4)\n",
    "        # We start with a learnable embedding for each user and item.  Then we compute the GNN\n",
    "        # outputs with those learnable embeddings as inputs.\n",
    "        user_embeddings = self.user_embeddings(blocks[0].srcnodes['user'].data[dgl.NID])\n",
    "        item_embeddings = self.item_embeddings(blocks[0].srcnodes['item'].data[dgl.NID])\n",
    "        for block, layer in zip(blocks, self.layers):\n",
    "            user_embeddings, item_embeddings = layer(block, user_embeddings, item_embeddings)\n",
    "        \n",
    "        # Compute predicted preference (Step 5)\n",
    "        user_embeddings = self.W(user_embeddings)\n",
    "        item_embeddings = self.V(item_embeddings)\n",
    "        \n",
    "        with pair_graph.local_scope():\n",
    "            pair_graph.nodes['user'].data['h'] = user_embeddings\n",
    "            pair_graph.nodes['item'].data['h'] = item_embeddings\n",
    "            pair_graph.apply_edges(fn.u_dot_v('h', 'h', 'r'))\n",
    "            \n",
    "            return pair_graph.edata['r']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Evaluation Metric\n",
    "\n",
    "In this tutorial we use root of mean squared error (RMSE) as the metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(pred, label):\n",
    "    return ((pred - label) ** 2).mean().sqrt()"
   ]
  },
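  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the metric, the following plain-Python sketch computes RMSE for three made-up predictions.  It mirrors the formula used by `rmse` without depending on PyTorch; the name `rmse_plain` and the numbers are illustrative only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def rmse_plain(preds, labels):\n",
    "    # Square root of the mean of the squared differences.\n",
    "    squared = [(p - y) ** 2 for p, y in zip(preds, labels)]\n",
    "    return math.sqrt(sum(squared) / len(squared))\n",
    "\n",
    "value = rmse_plain([3.0, 4.0, 5.0], [3.0, 2.0, 5.0])   # errors: 0, 2, 0\n",
    "print(value)\n",
    "```"
   ]
  },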
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop\n",
    "\n",
    "Note: in this tutorial we are directly validating on the test set for simplicity.  In practice we should split out some data points in the training set to form a validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:27<00:00,  6.59it/s, loss=1.6613]\n",
      "100%|██████████| 19/19 [00:01<00:00, 16.21it/s]\n",
      "  1%|          | 1/182 [00:00<00:28,  6.46it/s, loss=1.5311]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.338324785232544\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:43<00:00,  4.23it/s, loss=1.4943]\n",
      "100%|██████████| 19/19 [00:01<00:00, 13.83it/s]\n",
      "  1%|          | 1/182 [00:00<00:33,  5.33it/s, loss=1.2547]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.1455392837524414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:34<00:00,  5.23it/s, loss=1.0911]\n",
      "100%|██████████| 19/19 [00:01<00:00, 14.87it/s]\n",
      "  1%|          | 1/182 [00:00<00:33,  5.38it/s, loss=1.0869]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.0888710021972656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:33<00:00,  5.42it/s, loss=1.2039]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.72it/s]\n",
      "  1%|          | 1/182 [00:00<00:25,  7.08it/s, loss=0.9744]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.0589362382888794\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:35<00:00,  5.06it/s, loss=1.2157]\n",
      "100%|██████████| 19/19 [00:01<00:00, 14.28it/s]\n",
      "  1%|          | 1/182 [00:00<00:31,  5.73it/s, loss=1.0032]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.039975643157959\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:35<00:00,  5.13it/s, loss=0.9786]\n",
      "100%|██████████| 19/19 [00:01<00:00, 14.94it/s]\n",
      "  1%|          | 1/182 [00:00<00:29,  6.06it/s, loss=0.8799]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.02749502658844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:35<00:00,  5.09it/s, loss=1.0473]\n",
      "100%|██████████| 19/19 [00:01<00:00, 14.78it/s]\n",
      "  1%|          | 1/182 [00:00<00:29,  6.12it/s, loss=0.9842]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.0141799449920654\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:38<00:00,  4.69it/s, loss=0.9200]\n",
      "100%|██████████| 19/19 [00:01<00:00, 13.95it/s]\n",
      "  1%|          | 1/182 [00:00<00:31,  5.71it/s, loss=0.8705]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.0038763284683228\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:35<00:00,  5.19it/s, loss=1.0781]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.52it/s]\n",
      "  1%|          | 1/182 [00:00<00:30,  5.90it/s, loss=0.9905]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9980089664459229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:33<00:00,  5.40it/s, loss=0.8775]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.89it/s]\n",
      "  1%|          | 1/182 [00:00<00:25,  7.12it/s, loss=0.8713]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9922261834144592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:27<00:00,  6.64it/s, loss=1.1563]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.43it/s]\n",
      "  1%|          | 1/182 [00:00<00:35,  5.03it/s, loss=0.8117]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9862333536148071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:30<00:00,  6.05it/s, loss=0.8535]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.35it/s]\n",
      "  1%|          | 1/182 [00:00<00:30,  6.03it/s, loss=0.8250]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9824907183647156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:37<00:00,  4.79it/s, loss=0.7394]\n",
      "100%|██████████| 19/19 [00:01<00:00, 15.00it/s]\n",
      "  1%|          | 1/182 [00:00<00:30,  6.00it/s, loss=0.8677]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9791952967643738\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 182/182 [00:38<00:00,  4.76it/s, loss=0.8822]\n",
      "100%|██████████| 19/19 [00:01<00:00, 14.91it/s]\n",
      "  1%|          | 1/182 [00:00<00:31,  5.69it/s, loss=0.8937]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.9768857955932617\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "NUM_EPOCHS = 50\n",
    "HIDDEN_DIMS = 8\n",
    "\n",
    "model = GCMCRating(graph.number_of_nodes('user'), graph.number_of_nodes('item'), HIDDEN_DIMS, 5, NUM_LAYERS)\n",
    "opt = torch.optim.Adam(model.parameters())\n",
    "\n",
    "for _ in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    with tqdm.tqdm(train_dataloader) as t:\n",
    "        for pair_graph, blocks in t:\n",
    "            prediction = model(pair_graph, blocks)\n",
    "            loss = ((prediction - pair_graph.edata['rating']) ** 2).mean()\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            t.set_postfix({'loss': '%.4f' % loss.item()}, refresh=False)\n",
    "    model.eval()\n",
    "    with tqdm.tqdm(test_dataloader) as t:\n",
    "        with torch.no_grad():\n",
    "            predictions = []\n",
    "            ratings = []\n",
    "            for pair_graph, blocks in t:\n",
    "                prediction = model(pair_graph, blocks)\n",
    "                predictions.append(prediction)\n",
    "                ratings.append(pair_graph.edata['rating'])\n",
    "\n",
    "            predictions = torch.cat(predictions, 0)\n",
    "            ratings = torch.cat(ratings, 0)\n",
    "            print('RMSE:', rmse(predictions, ratings).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
